# -*- coding: utf-8 -*-
"""
Created on Tue Aug 20 09:47:06 2024

@author: xinyuan.wei
"""

import os
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import OneHotEncoder
import numpy as np

# Stand-age window to analyse and the width of each stand-age interval.
stage_min, stage_max, step = 0, 140, 10

# The input CSV is read from the parent of the current working directory.
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)
file_path = os.path.join(parent_directory, 'NE_Plot_Composition.csv')
data = pd.read_csv(file_path)

# Build the row filter from named pieces; a row is kept only if all hold.
condition_ok = (data['CONDID'] == 1) & (data['CONDPROP_UNADJ'] >= 0.9)
metrics_ok = (
    (data['eAG'] >= 1)
    & (data['sindex'] >= 0)
    & (data['evenness'] >= 0)
)
species_known = data['dominant_area_species'].notna()
# Half-open window: stage_min is excluded, stage_max is included.
age_in_window = (data['STDAGE'] > stage_min) & (data['STDAGE'] <= stage_max)

# Copy so later operations never touch the frame loaded from disk.
filtered_data = data[
    condition_ok & metrics_ok & species_known & age_in_window
].copy()

# Feature importances collected per stand-age interval. Each list gains one
# entry for every interval that has enough rows to train on.
feature_importances = {
    'stand_age': [],
    'sindex': [],
    'evenness': [],
    'dominant_area_species': []
}

# Loop through the stand age intervals
for start_age in range(stage_min, stage_max, step):
    end_age = start_age + step

    # Rows whose stand age falls in the half-open interval (start_age, end_age]
    age_filtered_data = filtered_data[(filtered_data['STDAGE'] > start_age) &
                                      (filtered_data['STDAGE'] <= end_age)].copy()

    if len(age_filtered_data) < 10:
        # Skip this interval if there's not enough data
        continue

    # Prepare the features and target variable
    X = age_filtered_data[['sindex', 'evenness', 'dominant_area_species']]
    y = age_filtered_data['eAG']

    # One-hot encode the dominant species code.
    # The `sparse` keyword was deprecated in scikit-learn 1.2 and removed in
    # 1.4, while its replacement `sparse_output` does not exist before 1.2.
    # Using the default (sparse) output and densifying it explicitly works on
    # every version.
    encoder = OneHotEncoder()
    species_encoded = encoder.fit_transform(X[['dominant_area_species']]).toarray()

    # Design matrix layout: column 0 = sindex, column 1 = evenness,
    # columns 2.. = one indicator column per species seen in this interval.
    X_encoded = np.concatenate([X[['sindex', 'evenness']].values, species_encoded], axis=1)

    # Split the data into training and testing sets (70% / 30%)
    X_train, X_test, y_train, y_test = train_test_split(X_encoded, y, test_size=0.3, random_state=42)

    # Initialize and train the Random Forest Regressor; the fixed random_state
    # keeps the split and the forest reproducible between runs.
    rf = RandomForestRegressor(n_estimators=100, random_state=42)
    rf.fit(X_train, y_train)

    # Calculate feature importances (same column order as X_encoded)
    importances = rf.feature_importances_

    # The species importance is the sum over all of its one-hot columns, so
    # it is reported as a single feature alongside sindex and evenness.
    sindex_importance = importances[0]
    evenness_importance = importances[1]
    dominant_species_importance = np.sum(importances[2:])

    # Store the feature importances
    feature_importances['stand_age'].append(f"{start_age}-{end_age}")
    feature_importances['sindex'].append(sindex_importance)
    feature_importances['evenness'].append(evenness_importance)
    feature_importances['dominant_area_species'].append(dominant_species_importance)

    # Optional: Evaluate model performance (can be removed if not needed)
    y_pred = rf.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    print(f"Stand age {start_age}-{end_age}: MSE = {mse}")

# Tabulate the collected importances: one row per stand-age interval.
importances_df = pd.DataFrame(feature_importances)

# Write the table to a CSV file inside parent_directory.
output_file_path = os.path.join(parent_directory, 'feature_importances_by_age.csv')
importances_df.to_csv(output_file_path, index=False)
print(f"Feature importances saved to {output_file_path}")

# Draw one line per feature, all sharing the interval labels on the x-axis.
plt.figure(figsize=(10, 6))
interval_labels = importances_df['stand_age']
series_styles = (
    ('sindex', 'blue'),
    ('evenness', 'green'),
    ('dominant_area_species', 'red'),
)
for column_name, line_color in series_styles:
    plt.plot(
        interval_labels,
        importances_df[column_name],
        label=column_name,
        color=line_color,
        marker='o',
    )

# Titles, axis labels and legend; tick labels are rotated 45 degrees.
plt.title('Feature Importance Across Stand Age Intervals')
plt.xlabel('Stand Age Interval (Years)')
plt.ylabel('Feature Importance')
plt.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

